#!/usr/bin/env python3
# G19_v2.1 — Mesh Certification (h vs h/2), aligned physical bins + amplitude density (per Δr)
# Control: present-act, boolean/ordinal. Deterministic DDA 1/r per shell; NO curves/weights/RNG in control.
# Two meshes: coarse N and fine 2N (h vs h/2). Both use the SAME physical radial windows and bin edges:
#  - Log-bins: identical physical edges; fine uses the same edges (scaled in shell units).
#  - Linear equal-Δr bins: coarse Δr = W shells; fine Δr = 2W shells so bin counts are equal.
# Readouts (diagnostics only): slope (mid-60% log fit), plateau CV (outer fraction), amplitude *density*
# (bin-sum divided by the bin width in that mesh's shell units: W coarse / 2W fine, i.e. the same physical
# Δr), and cross-mesh correlation computed on the SAME binned linear profile (outer fraction of bins).
# Acceptance checks per-mesh hygiene + cross-mesh invariance (Δslope, ΔCV, amp ratio, corr).

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple


def utc_ts() -> str:
    """Return the current UTC time as a filename-safe stamp: ``YYYY-MM-DDTHH-MM-SSZ``."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")


def ensure_dirs(root: str, subs: List[str]) -> None:
    """Create ``root/<sub>`` for every entry of *subs* (parents included); existing dirs are fine."""
    for sub in subs:
        target = os.path.join(root, sub)
        os.makedirs(target, exist_ok=True)


def wtxt(path: str, txt: str) -> None:
    """Write *txt* to *path* as UTF-8 text, replacing any previous content."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(txt)


def jdump(path: str, obj: dict) -> None:
    """Serialise *obj* to *path* as UTF-8 JSON, indented by 2 with keys sorted."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)


def sha256_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(1 << 20)
            if not block:  # b"" signals end of file
                break
            digest.update(block)
    return digest.hexdigest()


def isqrt(n: int) -> int:
    """Return floor(sqrt(n)) for a non-negative integer *n*.

    ``math.isqrt`` already yields an ``int``, so no extra conversion is needed.
    """
    return math.isqrt(n)


def edge_radius(cx: int, cy: int, N: int) -> int:
    """Distance, in cells, from (cx, cy) to the nearest side of an N x N grid indexed 0..N-1."""
    last = N - 1
    return min(cx, cy, last - cx, last - cy)


def build_shell_counts(N: int, cx: int, cy: int) -> Dict[int, int]:
    """Count the cells of an N x N grid in each integer shell r = isqrt(dx^2 + dy^2) around (cx, cy).

    Keys appear in first-seen order of a row-major (y outer, x inner) scan.
    """
    counts: Dict[int, int] = {}
    for y in range(N):
        dy2 = (y - cy) * (y - cy)  # constant across the row
        for x in range(N):
            dx = x - cx
            shell = math.isqrt(dx * dx + dy2)
            counts[shell] = counts.get(shell, 0) + 1
    return counts


def simulate_dda(shells: Dict[int, int], H: int, rate_num: int) -> Dict[int, int]:
    """Run an integer DDA for H ticks on every shell and return fire counts keyed like *shells*.

    Each non-zero shell r keeps its own accumulator: per tick it adds ``rate_num``
    and, once the accumulator reaches r, records one fire and subtracts r (at most
    one fire per tick). Shell 0 never fires. Shells do not interact, so each
    one is simulated to completion before moving on to the next.
    """
    fires: Dict[int, int] = {}
    for r in shells:
        count = 0
        if r != 0:
            acc = 0
            for _tick in range(H):
                acc += rate_num
                if acc >= r:
                    count += 1
                    acc -= r
        fires[r] = count
    return fires


def build_log_edges(r_lo: int, r_hi: int, n_bins: int) -> List[float]:
    """Return ``n_bins + 1`` log-spaced bin edges from max(1, r_lo) to max(1, r_hi).

    Interior edges are exp(log_lo + (log_hi - log_lo) * i / n_bins). The two
    endpoints are returned exactly (as floats) instead of via exp(log(r)).
    That round trip can land one ULP below/above the integer radius, and then
    ``math.ceil``/``math.floor`` in ``shells_in_range_int`` would silently drop
    the first or last shell of the window.

    Raises:
        ValueError: if ``n_bins < 1``.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    lo = float(max(1, r_lo))
    hi = float(max(1, r_hi))
    log_lo, log_hi = math.log(lo), math.log(hi)
    edges = [math.exp(log_lo + (log_hi - log_lo) * i / n_bins) for i in range(n_bins + 1)]
    # Pin the endpoints so integer window limits stay exact.
    edges[0], edges[-1] = lo, hi
    return edges


def shells_in_range_int(rs: List[int], lo: float, hi: float) -> List[int]:
    """Return the entries of *rs* inside the closed interval [ceil(lo), floor(hi)], order preserved."""
    first, last = math.ceil(lo), math.floor(hi)
    return list(filter(lambda r: first <= r <= last, rs))


def slope_from_logbins(shells: Dict[int, int], fires: Dict[int, int], H: int,
                       r_min: int, r_max: int, log_edges: List[float], fit_mid_frac: float):
    """Fit the log-log slope of mean fire rate vs. radius over log-spaced bins.

    Shells r with ``r > 0`` and ``r_min <= r <= r_max`` are grouped by the
    closed integer ranges [ceil(lo), floor(hi)] of consecutive *log_edges*;
    empty bins are skipped. For each bin:
      x = log(R / r_rep + 1e-12), with R the largest selected shell and r_rep
          the geometric mean of the bin's first and last shell;
      y = log(mean over the bin's shells of fires[r] / H + 1e-12).
    A least-squares line is fitted to the central ``fit_mid_frac`` of the
    points (all points if trimming would leave fewer than 2).

    Returns:
        (slope, r2, n_points); slope and r2 are NaN when n_points < 4 or the
        fitted x values have zero spread.
    """
    radii = sorted(r for r in shells if r > 0 and r_min <= r <= r_max)
    log_x: List[float] = []
    log_y: List[float] = []
    for lo, hi in zip(log_edges[:-1], log_edges[1:]):
        first, last = math.ceil(lo), math.floor(hi)
        members = [r for r in radii if first <= r <= last]
        if not members:
            continue
        mean_rate = sum([fires.get(r, 0) / H for r in members]) / len(members)
        r_rep = math.exp((math.log(members[0]) + math.log(members[-1])) / 2.0)
        log_x.append(math.log(radii[-1] / r_rep + 1e-12))
        log_y.append(math.log(mean_rate + 1e-12))

    n_pts = len(log_x)
    if n_pts < 4:
        return float("nan"), float("nan"), n_pts

    # Trim the same number of points from each end to keep the middle fraction.
    trim = int(round(n_pts * (1.0 - fit_mid_frac) / 2.0))
    if n_pts - 2 * trim >= 2:
        xs, ys = log_x[trim:n_pts - trim], log_y[trim:n_pts - trim]
    else:
        xs, ys = log_x, log_y

    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) * (x - x_mean) for x in xs)
    if sxx == 0:
        return float("nan"), float("nan"), n_pts

    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    syy = sum((y - y_mean) * (y - y_mean) for y in ys)
    sse = sum((y - (intercept + slope * x)) * (y - (intercept + slope * x)) for x, y in zip(xs, ys))
    r2 = 1.0 - (sse / syy if syy > 0 else 0.0)
    return slope, r2, n_pts


def build_linear_bins(r_min: int, r_max: int, width: int) -> List[tuple]:
    """Tile [r_min, r_max] with consecutive inclusive ``(lo, hi)`` bins of *width* shells.

    Only complete bins are produced; trailing shells that cannot fill a whole
    bin are dropped. The result is empty when fewer than *width* shells are
    available.

    Raises:
        ValueError: if ``width < 1``.
    """
    if width < 1:
        raise ValueError(f"width must be >= 1, got {width}")
    n_full = (r_max - r_min + 1) // width  # <= 0 (no bins) when the range is too short
    return [(r_min + i * width, r_min + (i + 1) * width - 1) for i in range(n_full)]


def plateau_from_bins(shells: Dict[int, int], fires: Dict[int, int], H: int,
                      bins: List[tuple], outer_frac: float):
    """Build the binned amplitude-density profile and its outer-window plateau CV.

    For each inclusive ``(lo, hi)`` bin, the amplitude is the sum over present
    shells of ``shells[r] * fires[r] / H`` (missing fires count as 0). It is
    divided by the bin width in shell units, ``hi - lo + 1``. The plateau uses
    the last ``max(1, round(k * outer_frac))`` bins: CV = population std / mean.

    Returns:
        (cv, outer_mean, k, profile). ``(nan, nan, 0, [])`` when there are no
        bins; ``(inf, 0.0, k, profile)`` when the outer mean is zero.
    """
    profile: List[float] = []
    for lo, hi in bins:
        bin_sum = 0.0
        for rr in range(lo, hi + 1):
            if rr not in shells:
                continue
            bin_sum += shells[rr] * (fires.get(rr, 0) / H)
        profile.append(bin_sum / max(1, hi - lo + 1))  # amplitude density

    n_bins = len(profile)
    if n_bins == 0:
        return float("nan"), float("nan"), 0, []

    n_outer = max(1, int(round(n_bins * outer_frac)))
    tail = profile[-n_outer:]
    mean = sum(tail) / len(tail)
    if mean == 0.0:
        return float("inf"), 0.0, n_bins, profile
    variance = sum((d - mean) * (d - mean) for d in tail) / len(tail)
    return math.sqrt(variance) / mean, mean, n_bins, profile


def run_panel(N: int, cx: int, cy: int, H: int, rate: int,
              r_min_slope: int, r_min_plat: int,
              r_max_glob: int,
              log_edges_phys: List[float],
              lin_bins_coarse: List[tuple],
              is_fine: bool,
              outer_frac: float = 0.75,
              fit_mid_frac: float = 0.60) -> dict:
    """Simulate one mesh and return its slope and plateau diagnostics.

    The coarse panel uses an N x N grid centred at (cx, cy). The fine panel
    (``is_fine=True``) uses a 2N x 2N grid centred at (2*cx, 2*cy). Every radius
    and edge given in coarse shell units is doubled, so both panels are read
    out on the same physical windows.

    Args:
        N, cx, cy: coarse grid size and centre (fine values are derived).
        H: number of DDA ticks.
        rate: DDA rate numerator added per tick.
        r_min_slope: inner radius (coarse shells) of the slope window.
        r_min_plat: not used here; the plateau window is already encoded in
            ``lin_bins_coarse``. Kept for call compatibility.
        r_max_glob: outer radius (coarse shells) shared by both meshes.
        log_edges_phys: log-bin edges in coarse shell units.
        lin_bins_coarse: inclusive (lo, hi) coarse-shell bins for the plateau.
        is_fine: select the fine (2N) mesh.
        outer_frac: fraction of outermost linear bins used for CV / amplitude.
        fit_mid_frac: central fraction of log bins used for the slope fit.

    Returns:
        dict with keys slope, r2, cv, amp (amplitude density), klog, klin, profile.
    """
    scale = 2 if is_fine else 1
    shells = build_shell_counts(N * scale, cx * scale, cy * scale)
    fires = simulate_dda(shells, H, rate)

    # Same physical slope window and log edges, expressed in this mesh's shell units.
    slope, r2, klog = slope_from_logbins(
        shells, fires, H,
        r_min_slope * scale, r_max_glob * scale,
        [e * scale for e in log_edges_phys],
        fit_mid_frac=fit_mid_frac,
    )

    if is_fine:
        # Shell r = isqrt(d^2) holds distances in [r, r + 1). Coarse shells lo..hi
        # therefore cover [lo, hi + 1) physically, which is fine shells 2*lo .. 2*hi + 1.
        lin_bins = [(lo * 2, hi * 2 + 1) for (lo, hi) in lin_bins_coarse]
    else:
        lin_bins = lin_bins_coarse
    cv, amp_density, klin, prof_density = plateau_from_bins(
        shells, fires, H, lin_bins, outer_frac=outer_frac
    )
    return {
        "slope": slope,
        "r2": r2,
        "cv": cv,
        "amp": amp_density,
        "klog": klog,
        "klin": klin,
        "profile": prof_density,
    }


def _pearson(a: List[float], b: List[float]) -> float:
    """Pearson correlation of *a* and *b*, truncated to the shorter length.

    Returns NaN when the sequences are empty or either has zero variance.
    """
    n = min(len(a), len(b))
    if n == 0:
        return float("nan")
    a, b = a[:n], b[:n]
    ma, mb = sum(a) / n, sum(b) / n
    num = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    denx = (sum((x - ma) * (x - ma) for x in a)) ** 0.5
    deny = (sum((y - mb) * (y - mb) for y in b)) ** 0.5
    return num / (denx * deny) if denx > 0 and deny > 0 else float("nan")


def main():
    """Run the coarse (h) vs fine (h/2) mesh certification.

    Reads the JSON manifest named by ``--manifest``. Under ``--outdir`` it writes
    a copy of the manifest, an env log, the metrics CSV, the audit JSON and the
    result line, and it prints the result line.

    Raises:
        RuntimeError: if ``mesh.N_fine`` is not ``2 * grid.N``, or the shared
            radial window is too small for the plateau or slope analysis.
        KeyError: if a required manifest key is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config", "outputs/metrics", "outputs/audits", "outputs/run_info", "logs"])

    with open(args.manifest, "r", encoding="utf-8") as f:
        M = json.load(f)
    # Archive the exact manifest used for this run next to its outputs.
    jdump(os.path.join(root, "config", "manifest_g19_v2_1.json"), M)

    wtxt(
        os.path.join(root, "logs", "env.txt"),
        "\n".join(
            [
                f"utc={utc_ts()}",
                f"os={os.name}",
                f"cwd={os.getcwd()}",
                f"python={sys.version.split()[0]}",
            ]
        ),
    )

    N = int(M["grid"]["N"])
    cx = int(M["grid"].get("cx", N // 2))
    cy = int(M["grid"].get("cy", N // 2))
    N_fine = int(M["mesh"]["N_fine"])
    H = int(M["H"])
    rate = int(M["rate_num"])
    omarg = int(M["outer_margin"])

    slope_cfg = M["slope"]
    plat_cfg = M["plateau"]
    r_min_slope = int(slope_cfg.get("r_min", 8))
    n_log = int(slope_cfg.get("n_log_bins", 12))
    fit_mid_frac = float(slope_cfg.get("fit_mid_frac", 0.60))
    r_min_plat = int(plat_cfg["r_min"])
    W_coarse = int(plat_cfg["shells_per_bin"])
    outer_frac = float(plat_cfg.get("outer_frac", 0.75))

    # run_panel always simulates the fine mesh on a 2N grid (h/2). Refuse a
    # manifest that records a different fine size instead of silently ignoring it.
    if N_fine != 2 * N:
        raise RuntimeError(
            f"mesh.N_fine={N_fine} must equal 2*grid.N={2 * N} (fine panel is simulated at h/2)."
        )

    # Edge radii are in each mesh's own shell units; convert the fine one to
    # coarse (physical) units before taking the common limit.
    R_edge_c = edge_radius(cx, cy, N)
    R_edge_f = edge_radius(cx * 2, cy * 2, N_fine)
    r_max_glob = min(R_edge_c, R_edge_f // 2) - omarg
    if r_max_glob <= r_min_plat + W_coarse:
        raise RuntimeError("r_max_glob too small; adjust margins.")
    if r_max_glob <= r_min_slope:
        raise RuntimeError("r_max_glob must exceed slope.r_min; adjust margins.")

    log_edges = build_log_edges(r_min_slope, r_max_glob, n_log)
    lin_bins_coarse = build_linear_bins(r_min_plat, r_max_glob, W_coarse)

    coarse = run_panel(N, cx, cy, H, rate, r_min_slope, r_min_plat, r_max_glob,
                       log_edges, lin_bins_coarse, is_fine=False,
                       outer_frac=outer_frac, fit_mid_frac=fit_mid_frac)
    fine = run_panel(N, cx, cy, H, rate, r_min_slope, r_min_plat, r_max_glob,
                     log_edges, lin_bins_coarse, is_fine=True,
                     outer_frac=outer_frac, fit_mid_frac=fit_mid_frac)

    # Correlate the two profiles over the same outer window that the plateau
    # CV/amplitude used (same outer_frac, same bin count).
    K = min(coarse["klin"], fine["klin"])
    take = max(1, int(round(K * outer_frac)))
    corr = _pearson(coarse["profile"][:K][-take:], fine["profile"][:K][-take:])

    dslope = abs(coarse["slope"] - fine["slope"])
    dcv = abs(coarse["cv"] - fine["cv"])
    amp_ratio = (fine["amp"] / coarse["amp"]) if (coarse["amp"] > 0) else float("inf")

    mcsv = os.path.join(root, "outputs/metrics", "g19_v2_1_mesh_metrics.csv")
    with open(mcsv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(
            [
                "panel",
                "slope",
                "r2",
                "cv",
                "amp_density",
                "log_bins",
                "lin_nbins",
                "r_max_glob",
                "outer_bins_used",
            ]
        )
        # r_max_glob is reported in each panel's own shell units.
        for name, panel, r_max_units in (("coarse", coarse, r_max_glob), ("fine", fine, r_max_glob * 2)):
            w.writerow(
                [
                    name,
                    f"{panel['slope']:.6f}",
                    f"{panel['r2']:.6f}",
                    f"{panel['cv']:.6f}",
                    f"{panel['amp']:.6f}",
                    panel["klog"],
                    panel["klin"],
                    r_max_units,
                    take,
                ]
            )

    acc = M["acceptance"]

    def _panel_ok(p: dict) -> bool:
        # Per-mesh hygiene: slope near target, good fit, flat outer plateau.
        return (
            abs(p["slope"] - acc["slope_target"]) <= acc["slope_tol_abs"]
            and p["r2"] >= acc["r2_min"]
            and p["cv"] <= acc["cv_max"]
        )

    perC_ok = _panel_ok(coarse)
    perF_ok = _panel_ok(fine)
    inv_ok = (dslope <= acc["delta_slope_max"]) and (dcv <= acc["delta_cv_max"])
    amp_ok = (abs(amp_ratio - 1.0) <= acc["amp_rel_tol"])
    corr_ok = (corr >= acc["corr_min"])  # NaN correlation never passes
    passed = bool(perC_ok and perF_ok and inv_ok and amp_ok and corr_ok)

    audit = {
        "sim": "G19_mesh_cert_v2_1",
        "coarse": {
            "slope": coarse["slope"],
            "r2": coarse["r2"],
            "cv": coarse["cv"],
            "amp_density": coarse["amp"],
        },
        "fine": {
            "slope": fine["slope"],
            "r2": fine["r2"],
            "cv": fine["cv"],
            "amp_density": fine["amp"],
        },
        "cross": {
            "delta_slope": dslope,
            "delta_cv": dcv,
            "amp_ratio_fine_over_coarse": amp_ratio,
            "corr_profile_outer": corr,
        },
        "accept": acc,
        "pass": passed,
    }
    # NOTE: json's default allow_nan=True writes NaN/inf metrics as NaN/Infinity.
    jdump(os.path.join(root, "outputs/audits", "g19_audit.json"), audit)

    result_line = (
        "G19_v2_1 PASS={p} slope_c={sc:.4f} slope_f={sf:.4f} Δslope={ds:.4f} "
        "cv_c={cc:.4f} cv_f={cf:.4f} Δcv={dc:.4f} amp_ratio={ar:.3f} corr={cr:.4f}"
    ).format(
        p=passed,
        sc=coarse["slope"],
        sf=fine["slope"],
        ds=dslope,
        cc=coarse["cv"],
        cf=fine["cv"],
        dc=dcv,
        ar=amp_ratio,
        cr=corr,
    )
    wtxt(os.path.join(root, "outputs/run_info", "result_line.txt"), result_line)
    print(result_line)


if __name__ == "__main__":
    # Run the certification only when executed as a script, not on import.
    main()
